{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# MNIST 数字识别"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import torch.nn\n",
    "import torch.optim\n",
    "import torchvision.datasets\n",
    "import torchvision.transforms\n",
    "import torch.utils.data"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 读取数据"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Downloading http://yann.lecun.com/exdb/mnist/train-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/train-labels-idx1-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-images-idx3-ubyte.gz\n",
      "Downloading http://yann.lecun.com/exdb/mnist/t10k-labels-idx1-ubyte.gz\n",
      "Processing...\n",
      "Done!\n"
     ]
    }
   ],
   "source": [
    "# Build both MNIST splits with identical settings; only the `train` flag differs.\n",
    "# download=True lets torchvision fetch the files into ./data/mnist, and\n",
    "# ToTensor() is applied to every image as it is read.\n",
    "train_dataset, test_dataset = (\n",
    "    torchvision.datasets.MNIST(\n",
    "        root='./data/mnist',\n",
    "        train=is_train,\n",
    "        transform=torchvision.transforms.ToTensor(),\n",
    "        download=True,\n",
    "    )\n",
    "    for is_train in (True, False)\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "len(train_loader) = 600\n",
      "len(test_loader) = 100\n",
      "image.size() = torch.Size([100, 1, 28, 28])\n",
      "labels.size() = torch.Size([100])\n"
     ]
    }
   ],
   "source": [
    "batch_size = 100\n",
    "\n",
    "# Reshuffle the training set on every epoch so the optimizer does not see the\n",
    "# same batch order each pass. Seeding first makes that order repeatable on a\n",
    "# fresh Run All (assumes the sampler draws from torch's global RNG, which is\n",
    "# the default when no generator is passed).\n",
    "torch.manual_seed(0)\n",
    "train_loader = torch.utils.data.DataLoader(\n",
    "        dataset=train_dataset, batch_size=batch_size, shuffle=True)\n",
    "print('len(train_loader) = {}'.format(len(train_loader)))\n",
    "\n",
    "# Batch order has no effect on accuracy, so the test set stays unshuffled.\n",
    "test_loader = torch.utils.data.DataLoader(\n",
    "        dataset=test_dataset, batch_size=batch_size)\n",
    "print('len(test_loader) = {}'.format(len(test_loader)))\n",
    "\n",
    "# Pull a single batch to check the tensor shapes; `images` and `labels` stay\n",
    "# defined afterwards for the preview plot.\n",
    "images, labels = next(iter(train_loader))\n",
    "print('image.size() = {}'.format(images.size()))\n",
    "print('labels.size() = {}'.format(labels.size()))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "Text(0.5,1,'label = 5')"
      ]
     },
     "execution_count": 4,
     "metadata": {},
     "output_type": "execute_result"
    },
    {
     "data": {
      "image/png": "iVBORw0KGgoAAAANSUhEUgAAAP8AAAEICAYAAACQ6CLfAAAABHNCSVQICAgIfAhkiAAAAAlwSFlzAAALEgAACxIB0t1+/AAAADl0RVh0U29mdHdhcmUAbWF0cGxvdGxpYiB2ZXJzaW9uIDIuMi4yLCBodHRwOi8vbWF0cGxvdGxpYi5vcmcvhp/UCwAAEC5JREFUeJzt3X2sVHV+x/H3R9S0PiKxomVFFmtx1Vh2g9i6ZtVYVjEavT5spDWx0YrtSqNNS2r5R02LtfWh1WhdMOpCsouaqgXpdtGKil0b4lVRWVx3rUFFb2ENIuBj4H77xz3sXvHOb+bOnJkz9/4+r2QyZ+Z7zpwvEz73nJlzzvwUEZhZfvaougEzq4bDb5Yph98sUw6/WaYcfrNMOfxmmXL4RyBJ6yX9YYPzhqTfaXI9TS9r3c/ht46S9LSkTyVtL26vV91Trhx+q8KciNivuE2puplcOfwjnKTpkv5H0hZJfZLulLT3brOdJelNSe9LulnSHoOWv0zSa5I+kLRC0hEd/idYRRz+kW8n8JfAwcAfAKcD391tnh5gGvAN4FzgMgBJ5wHzgPOB3wKeBZY0slJJ/1r8wRnq9kqdxf+h+EP0E0mnNvSvtNLJ5/aPPJLWA38aEf81RO0a4JSI6CkeBzAzIn5cPP4ucEFEnC7pP4F/i4h7i9oewHbgaxHxVrHsURHxRom9nwisAz4HLgbuBKZGxP+WtQ5rjLf8I5yk35W0XNL/SdoK3MjAXsBg7wyafgv47WL6COD2XVtsYDMgYEK7+o2I1RGxLSI+i4hFwE+As9q1PqvN4R/57gZ+xsAW+gAGduO12zyHD5qeCLxXTL8DXBkRYwfdfjMinqu3UknfG/SN/e63nw6j/xiiX+sAh3/k2x/YCmyXdDTw50PMM1fSQZIOB64GHiye/x7wt5KOBZB0oKSLGllpRPzZoG/sd78dO9QyksZKOkPSb0jaU9IfA98CVgzvn2xlcPhHvr8G/gjYBtzDr4M92FLgBWAN8B/AvQAR8Sjwj8ADxUeGtcDMNva6F/D3wC+B94G/AM6LCB/rr4C/8DPLlLf8Zply+M0y5fCbZcrhN8vUnp1cWXHGmJm1UUQ0dN5ES1t+SWdKel3SG5KubeW1zKyzmj7UJ2kM8HNgBrABeB6YFRHrEst4y2/WZp3Y8k8H3oiINyPic+ABBq4YM7MRoJXwT+CLF4xsYIgLQiTNltQrqbeFdZlZyVr5wm+oXYsv7dZHxEJgIXi336ybtLLl38AXrxb7Cr++WszMulwr4X8eOErSV4ufjboYWFZOW2bWbk3v9kfEDklzGLgccwxwX0QM5zpuM6tQR6/q82d+s/bryEk+ZjZyOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y1TTQ3TbyDBmzJhk/cADD2zr+ufMmVOzts8++ySXnTJlSrJ+1VVXJeu33HJLzdqsWbOSy3766afJ+k033ZSs33DDDcl6N2gp/JLWA9uAncCOiJhWRlNm1n5lbPlPi4j3S3gdM+sgf+Y3y1Sr4Q/gcUkvSJo91AySZkvqldTb4rrMrESt7vZ/MyLek3QI8ISkn0XEqsEzRMRCYCGApGhxfWZWkpa2/BHxXnG/CXgUmF5GU2bWfk2HX9K+kvbfNQ18G1hbVmNm1l6t7PaPBx6VtOt1fhgRPy6lq1Fm4sSJyfree++drJ900knJ+sknn1yzNnbs2OSyF1xwQbJepQ0bNiTrd9xxR7Le09NTs7Zt27bksi+//HKy/swzzyTrI0HT4Y+IN4HfK7EXM+sgH+ozy5TDb5Yph98sUw6/WaYcfrNMKaJzJ92N1jP8pk6dmqyvXLkyWW/3ZbXd
qr+/P1m/7LLLkvXt27c3ve6+vr5k/YMPPkjWX3/99abX3W4RoUbm85bfLFMOv1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUj/OXYNy4ccn66tWrk/XJkyeX2U6p6vW+ZcuWZP20006rWfv888+Ty+Z6/kOrfJzfzJIcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5YpD9Fdgs2bNyfrc+fOTdbPPvvsZP2ll15K1uv9hHXKmjVrkvUZM2Yk6x999FGyfuyxx9asXX311cllrb285TfLlMNvlimH3yxTDr9Zphx+s0w5/GaZcvjNMuXr+bvAAQcckKzXG056wYIFNWuXX355ctlLLrkkWV+yZEmybt2ntOv5Jd0naZOktYOeGyfpCUm/KO4PaqVZM+u8Rnb7vw+cudtz1wJPRsRRwJPFYzMbQeqGPyJWAbufv3ousKiYXgScV3JfZtZmzZ7bPz4i+gAiok/SIbVmlDQbmN3kesysTdp+YU9ELAQWgr/wM+smzR7q2yjpMIDiflN5LZlZJzQb/mXApcX0pcDSctoxs06pu9svaQlwKnCwpA3AdcBNwEOSLgfeBi5qZ5Oj3datW1ta/sMPP2x62SuuuCJZf/DBB5P1/v7+ptdt1aob/oiYVaN0esm9mFkH+fRes0w5/GaZcvjNMuXwm2XK4TfLlC/pHQX23XffmrXHHnssuewpp5ySrM+cOTNZf/zxx5N16zwP0W1mSQ6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y5SP849yRx55ZLL+4osvJutbtmxJ1p966qlkvbe3t2btrrvuSi7byf+bo4mP85tZksNvlimH3yxTDr9Zphx+s0w5/GaZcvjNMuXj/Jnr6elJ1u+///5kff/992963fPmzUvWFy9enKz39fU1ve7RzMf5zSzJ4TfLlMNvlimH3yxTDr9Zphx+s0w5/GaZ8nF+SzruuOOS9dtuuy1ZP/305gdzXrBgQbI+f/78ZP3dd99tet0jWWnH+SXdJ2mTpLWDnrte0ruS1hS3s1pp1sw6r5Hd/u8DZw7x/D9HxNTi9qNy2zKzdqsb/ohYBWzuQC9m1kGtfOE3R9IrxceCg2rNJGm2pF5JtX/Mzcw6rtnw3w0cCUwF+oBba80YEQsjYlpETGtyXWbWBk2FPyI2RsTOiOgH7gGml9uWmbVbU+GXdNighz3A2lrzmll3qnucX9IS4FTgYGAjcF3xeCoQwHrgyoioe3G1j/OPPmPHjk3WzznnnJq1er8VIKUPV69cuTJZnzFjRrI+WjV6nH/PBl5o1hBP3zvsjsysq/j0XrNMOfxmmXL4zTLl8JtlyuE3y5Qv6bXKfPbZZ8n6nnumD0bt2LEjWT/jjDNq1p5++unksiOZf7rbzJIcfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5apulf1Wd6OP/74ZP3CCy9M1k844YSatXrH8etZt25dsr5q1aqWXn+085bfLFMOv1mmHH6zTDn8Zply+M0y5fCbZcrhN8uUj/OPclOmTEnW58yZk6yff/75yfqhhx467J4atXPnzmS9ry/9a/H9/f1ltjPqeMtvlimH3yxTDr9Zphx+s0w5/GaZcvjNMuXwm2Wq7nF+SYcDi4FDgX5gYUTcLmkc8CAwiYFhur8TER+0r9V81TuWPmvWUAMpD6h3HH/SpEnNtFSK3t7eZH3+/PnJ+rJly8psJzuNbPl3AH8VEV8Dfh+4StIxwLXAkxFxFPBk8djMRoi64Y+Ivoh4sZjeBrwGTADOBRYVsy0CzmtXk2ZWvmF95pc0Cfg6sBoYHxF9MPAHAjik7ObMrH0aPrdf0n7Aw8A1EbFVamg4MCTNBmY3156ZtUtDW35JezEQ/B9ExCPF0xslHVbUDwM2DbVsRCyMiGkRMa2Mhs2sHHXDr4FN/L3AaxFx26DSMuDSYvpSYGn57ZlZu9QdolvSycCzwKsMHOoDmMfA5/6HgInA28BFEbG5zmtlOUT3+PHjk/Vjjjkm
Wb/zzjuT9aOPPnrYPZVl9erVyfrNN99cs7Z0aXp74Utym9PoEN11P/NHxH8DtV7s9OE0ZWbdw2f4mWXK4TfLlMNvlimH3yxTDr9Zphx+s0z5p7sbNG7cuJq1BQsWJJedOnVqsj558uSmeirDc889l6zfeuutyfqKFSuS9U8++WTYPVlneMtvlimH3yxTDr9Zphx+s0w5/GaZcvjNMuXwm2Uqm+P8J554YrI+d+7cZH369Ok1axMmTGiqp7J8/PHHNWt33HFHctkbb7wxWf/oo4+a6sm6n7f8Zply+M0y5fCbZcrhN8uUw2+WKYffLFMOv1mmsjnO39PT01K9FevWrUvWly9fnqzv2LEjWU9dc79ly5bkspYvb/nNMuXwm2XK4TfLlMNvlimH3yxTDr9Zphx+s0wpItIzSIcDi4FDgX5gYUTcLul64Argl8Ws8yLiR3VeK70yM2tZRKiR+RoJ/2HAYRHxoqT9gReA84DvANsj4pZGm3L4zdqv0fDXPcMvIvqAvmJ6m6TXgGp/usbMWjasz/ySJgFfB1YXT82R9Iqk+yQdVGOZ2ZJ6JfW21KmZlarubv+vZpT2A54B5kfEI5LGA+8DAfwdAx8NLqvzGt7tN2uz0j7zA0jaC1gOrIiI24aoTwKWR8RxdV7H4Tdrs0bDX3e3X5KAe4HXBge/+CJwlx5g7XCbNLPqNPJt/8nAs8CrDBzqA5gHzAKmMrDbvx64svhyMPVa3vKbtVmpu/1lcfjN2q+03X4zG50cfrNMOfxmmXL4zTLl8JtlyuE3y5TDb5Yph98sUw6/WaYcfrNMOfxmmXL4zTLl8JtlyuE3y1Snh+h+H3hr0OODi+e6Ubf21q19gXtrVpm9HdHojB29nv9LK5d6I2JaZQ0kdGtv3doXuLdmVdWbd/vNMuXwm2Wq6vAvrHj9Kd3aW7f2Be6tWZX0VulnfjOrTtVbfjOriMNvlqlKwi/pTEmvS3pD0rVV9FCLpPWSXpW0purxBYsxEDdJWjvouXGSnpD0i+J+yDESK+rteknvFu/dGklnVdTb4ZKekvSapJ9Kurp4vtL3LtFXJe9bxz/zSxoD/ByYAWwAngdmRcS6jjZSg6T1wLSIqPyEEEnfArYDi3cNhSbpn4DNEXFT8YfzoIj4my7p7XqGOWx7m3qrNaz8n1Dhe1fmcPdlqGLLPx14IyLejIjPgQeAcyvoo+tFxCpg825PnwssKqYXMfCfp+Nq9NYVIqIvIl4sprcBu4aVr/S9S/RViSrCPwF4Z9DjDVT4BgwhgMclvSBpdtXNDGH8rmHRivtDKu5nd3WHbe+k3YaV75r3rpnh7stWRfiHGkqom443fjMivgHMBK4qdm+tMXcDRzIwhmMfcGuVzRTDyj8MXBMRW6vsZbAh+qrkfasi/BuAwwc9/grwXgV9DCki3ivuNwGPMvAxpZts3DVCcnG/qeJ+fiUiNkbEzojoB+6hwveuGFb+YeAHEfFI8XTl791QfVX1vlUR/ueBoyR9VdLewMXAsgr6+BJJ+xZfxCBpX+DbdN/Q48uAS4vpS4GlFfbyBd0ybHutYeWp+L3rtuHuKznDrziU8S/AGOC+iJjf8SaGIGkyA1t7GLjc+YdV9iZpCXAqA5d8bgSuA/4deAiYCLwNXBQRHf/irUZvpzLMYdvb1FutYeVXU+F7V+Zw96X049N7zfLkM/zMMuXwm2XK4TfLlMNvlimH3yxTDr9Zphx+s0z9PwewAQg/CrbEAAAAAElFTkSuQmCC\n",
      "text/plain": [
       "<Figure size 432x288 with 1 Axes>"
      ]
     },
     "metadata": {},
     "output_type": "display_data"
    }
   ],
   "source": [
    "%matplotlib inline\n",
    "import matplotlib.pyplot as plt\n",
    "\n",
    "# Show the first sample of the batch fetched above. A batch is (N, 1, 28, 28),\n",
    "# so [0, 0] selects the single 28x28 channel of sample 0.\n",
    "fig, ax = plt.subplots()\n",
    "ax.imshow(images[0, 0], cmap='gray')\n",
    "# .item() turns the 0-dim label tensor into a plain Python int for the title;\n",
    "# the trailing semicolon keeps the Text(...) repr out of the cell output.\n",
    "ax.set_title('label = {}'.format(labels[0].item()));"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 训练"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "第0趟第0批：loss = 2.30195\n",
      "第0趟第100批：loss = 0.752388\n",
      "第0趟第200批：loss = 0.688536\n",
      "第0趟第300批：loss = 0.567242\n",
      "第0趟第400批：loss = 0.405684\n",
      "第0趟第500批：loss = 0.438673\n",
      "第1趟第0批：loss = 0.336651\n",
      "第1趟第100批：loss = 0.349938\n",
      "第1趟第200批：loss = 0.426128\n",
      "第1趟第300批：loss = 0.35121\n",
      "第1趟第400批：loss = 0.3013\n",
      "第1趟第500批：loss = 0.366111\n",
      "第2趟第0批：loss = 0.269673\n",
      "第2趟第100批：loss = 0.302057\n",
      "第2趟第200批：loss = 0.361562\n",
      "第2趟第300批：loss = 0.297809\n",
      "第2趟第400批：loss = 0.273201\n",
      "第2趟第500批：loss = 0.341415\n",
      "第3趟第0批：loss = 0.239015\n",
      "第3趟第100批：loss = 0.282269\n",
      "第3趟第200批：loss = 0.327436\n",
      "第3趟第300批：loss = 0.27465\n",
      "第3趟第400批：loss = 0.260817\n",
      "第3趟第500批：loss = 0.327343\n",
      "第4趟第0批：loss = 0.221851\n",
      "第4趟第100批：loss = 0.272043\n",
      "第4趟第200批：loss = 0.305498\n",
      "第4趟第300批：loss = 0.261793\n",
      "第4趟第400批：loss = 0.25385\n",
      "第4趟第500批：loss = 0.317882\n"
     ]
    }
   ],
   "source": [
    "# Softmax-regression baseline: one linear layer from the 28*28 flattened pixels\n",
    "# to 10 class scores, trained with cross-entropy and Adam (default settings).\n",
    "fc = torch.nn.Linear(28 * 28, 10)\n",
    "criterion = torch.nn.CrossEntropyLoss()\n",
    "optimizer = torch.optim.Adam(fc.parameters())\n",
    "\n",
    "num_epochs = 5\n",
    "for epoch in range(num_epochs):\n",
    "    for batch_idx, (images, labels) in enumerate(train_loader):\n",
    "        # Flatten each image of the batch into one row of 28*28 values.\n",
    "        flat_images = images.reshape(-1, 28 * 28)\n",
    "\n",
    "        # Forward pass and loss for this batch.\n",
    "        logits = fc(flat_images)\n",
    "        loss = criterion(logits, labels)\n",
    "\n",
    "        # Drop gradients left from the previous step, backpropagate, update.\n",
    "        optimizer.zero_grad()\n",
    "        loss.backward()\n",
    "        optimizer.step()\n",
    "\n",
    "        # Report the current batch loss once every 100 batches.\n",
    "        if batch_idx % 100 == 0:\n",
    "            print('第{}趟第{}批：loss = {:g}'.format(epoch, batch_idx, loss.item()))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## 测试"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "测试集上的准确率：92.2%\n"
     ]
    }
   ],
   "source": [
    "# Accuracy of the trained layer `fc` over the whole test set.\n",
    "correct = 0\n",
    "total = 0\n",
    "# Inference only: turn autograd off so no graph is recorded for the forward\n",
    "# passes, which saves memory and time without changing the predictions.\n",
    "with torch.no_grad():\n",
    "    for images, labels in test_loader:\n",
    "        # Flatten each image of the batch into one row of 28*28 values.\n",
    "        x = images.reshape(-1, 28 * 28)\n",
    "        preds = fc(x)\n",
    "        # The index of the largest class score is the predicted digit.\n",
    "        predicted = torch.argmax(preds, 1)\n",
    "        total += labels.size(0)\n",
    "        correct += (predicted == labels).sum().item()\n",
    "\n",
    "accuracy = correct / total\n",
    "print('测试集上的准确率：{:.1%}'.format(accuracy))"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.6.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
